package io.sqrtqiezi.spark.lagou

import org.apache.spark.sql.SparkSession

object IPAnalysis {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder()
      .master("local[*]")
      .appName(this.getClass.getCanonicalName.init)
      .getOrCreate()

    // 1.0.1.0|1.0.3.255|16777472|16778239|亚洲|中国|福建|福州||电信|350100|China|CN|119.306239|26.075302
    val ipRDD = spark.sparkContext.textFile("lagou-data/ip.dat")
      .map(line => line.split("\\|"))
      .map(item => (ip2int(item(0)), ip2int(item(1)), item(6) + item(7)))
      .sortBy(_._1)

    // 创建广播表
    val ipTable = spark.sparkContext.broadcast(ipRDD.collect)

    //格式：时间戳、IP地址、访问网址、访问数据、浏览器信息
    val accessLog = spark.sparkContext.textFile("lagou-data/http.log")
      .map(line => line.split("\\|"))
      .map(item => ip2int(item(1))) // 只保留记录访问的 IP

    val result = accessLog
      .map(ip => ipTable.value
        .find { case (start, end, _) => ip >= start && ip <= end } // 根据 IP 范围查找
        .getOrElse((0,0,"unknown"))
        ._3
      )
      .countByValue

    for ((name, count) <- result) {
      println(s"$name : $count")
    }

    //    河北石家庄 : 383
    //    重庆重庆 : 868
    //    云南昆明 : 126
    //    北京北京 : 1535
    //    陕西西安 : 1824
    spark.stop()
  }

  // convert ip string to int
  // sample：assert(ip2int("36.193.58.0") == 616643072)
  def ip2int(ip: String): Int = {
    val terms = ip.split("\\.")
      .map(_.toInt)
    terms(0) << 24 | terms(1) << 16 | terms(2) << 8 | terms(3)
  }
}
